package com.chatplus.application.config.ws.interceptor;

import cn.dev33.satoken.stp.StpUtil;
import cn.hutool.http.HtmlUtil;
import com.chatplus.application.common.logging.SouthernQuietLogger;
import com.chatplus.application.common.logging.SouthernQuietLoggerFactory;
import com.chatplus.application.domain.request.ws.ChatWebSocketRequest;
import com.chatplus.application.web.satoken.helper.LoginHelper;
import com.chatplus.application.web.satoken.model.LoginUser;
import org.apache.commons.lang3.StringUtils;
import org.jetbrains.annotations.NotNull;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.stereotype.Component;
import org.springframework.util.MultiValueMap;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.server.HandshakeInterceptor;
import org.springframework.web.util.UriComponents;
import org.springframework.web.util.UriComponentsBuilder;

import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;

/**
 * Interceptor for WebSocket handshake requests.
 *
 * <p>Only requests whose path equals {@link ChatWebSocketRequest#WS_URL_PATH} are accepted. For those,
 * the {@code token} query parameter is authenticated. On success, a {@link ChatWebSocketRequest} built
 * from the query parameters is stored in the handshake attributes under that same path key. Requests
 * for any other path, with an invalid token, or with missing/malformed parameters are rejected.
 */
@Component
public class PlusWebSocketInterceptor implements HandshakeInterceptor {

    private static final SouthernQuietLogger LOGGER = SouthernQuietLoggerFactory.getLogger(PlusWebSocketInterceptor.class);

    /** Query parameter carrying the login token; it is never written to logs. */
    private static final String TOKEN_PARAM = "token";

    /**
     * Runs before the handshake: validates the chat endpoint's login token and query parameters.
     *
     * @param request    the handshake HTTP request
     * @param response   the handshake HTTP response (not used)
     * @param wsHandler  the target WebSocket handler (not used)
     * @param attributes handshake attributes; receives the {@link ChatWebSocketRequest} on success
     * @return {@code true} to proceed with the handshake, {@code false} to reject it
     */
    @Override
    public boolean beforeHandshake(@NotNull ServerHttpRequest request, @NotNull ServerHttpResponse response, @NotNull WebSocketHandler wsHandler, Map<String, Object> attributes) {
        // Extract the GET parameters; HTML entities (e.g. "&amp;") in the URI are unescaped first.
        String url = HtmlUtil.unescape(request.getURI().toString());
        // fromUriString accepts any scheme; fromHttpUrl is deprecated (Spring 6.2) and, in older
        // versions, rejects anything other than http/https.
        UriComponents uriComponents = UriComponentsBuilder.fromUriString(url).build();
        // NOTE(review): values from build() are not percent-decoded; confirm clients never encode them.
        MultiValueMap<String, String> queryParams = uriComponents.getQueryParams();
        String path = uriComponents.getPath();
        if (StringUtils.isEmpty(path) || !path.equals(ChatWebSocketRequest.WS_URL_PATH)) {
            LOGGER.message("WebSocket握手失败，找不到对应的处理器")
                    .context("path", path)
                    .context("queryParams", withoutToken(queryParams))
                    .info();
            return false;
        }

        // getFirst is null-safe: a missing token falls through to checkUserLogin's empty check
        // instead of throwing a NullPointerException.
        String token = queryParams.getFirst(TOKEN_PARAM);
        Long userId = checkUserLogin(token);
        if (Objects.isNull(userId)) {
            return false;
        }

        // All of these come from the client, so reject the handshake on bad input rather than
        // letting NullPointerException/NumberFormatException escape the interceptor.
        String sessionId = queryParams.getFirst("session_id");
        String chatId = queryParams.getFirst("chat_id");
        Long roleId = parseLongParam(queryParams.getFirst("role_id"));
        Long modelId = parseLongParam(queryParams.getFirst("model_id"));
        if (sessionId == null || chatId == null || roleId == null || modelId == null) {
            LOGGER.message("WebSocket握手失败，请求参数缺失或格式错误")
                    .context("path", path)
                    .context("queryParams", withoutToken(queryParams))
                    .warn();
            return false;
        }

        ChatWebSocketRequest chatWebSocketRequest = ChatWebSocketRequest.builder()
                .sessionId(sessionId)
                .roleId(roleId)
                .chatId(chatId)
                .modelId(modelId)
                .token(token)
                .userId(userId)
                .build();
        attributes.put(ChatWebSocketRequest.WS_URL_PATH, chatWebSocketRequest);
        return true;
    }

    /**
     * Runs after the handshake. It intentionally does nothing.
     *
     * @param request   the handshake HTTP request
     * @param response  the handshake HTTP response
     * @param wsHandler the target WebSocket handler
     * @param exception the exception raised during the handshake, or {@code null}
     */
    @Override
    public void afterHandshake(@NotNull ServerHttpRequest request, @NotNull ServerHttpResponse response, @NotNull WebSocketHandler wsHandler, Exception exception) {
        // No post-handshake processing is required.
    }

    /**
     * Resolves the user id for a login token.
     *
     * @param token the token from the query string; may be {@code null}
     * @return the logged-in user's id, or {@code null} when the token is empty or not logged in
     */
    private Long checkUserLogin(String token) {
        if (StringUtils.isEmpty(token)) {
            LOGGER.message("WebSocket认证失败,无法访问系统资源").warn();
            return null;
        }
        LoginUser loginUser = LoginHelper.getLoginUser(token);
        if (loginUser == null) {
            LOGGER.message("WebSocket认证失败,无法访问系统资源").warn();
            // The token has no associated login user: invalidate it.
            StpUtil.logoutByTokenValue(token);
            return null;
        }
        return loginUser.getUserId();
    }

    /**
     * Parses a numeric query parameter value.
     *
     * @param value the raw parameter value; may be {@code null}
     * @return the parsed value, or {@code null} if it is absent, empty, or not a valid {@code long}
     */
    private static Long parseLongParam(String value) {
        if (StringUtils.isEmpty(value)) {
            return null;
        }
        try {
            return Long.parseLong(value);
        } catch (NumberFormatException ignored) {
            // Malformed client input; the caller rejects the handshake.
            return null;
        }
    }

    /**
     * Returns a copy of the query parameters with the login token removed, so it is safe to log.
     *
     * @param queryParams the parsed query parameters
     * @return an insertion-ordered copy without the {@code token} entry
     */
    private static Map<String, List<String>> withoutToken(MultiValueMap<String, String> queryParams) {
        // getQueryParams() returns an unmodifiable map, so copy before removing.
        Map<String, List<String>> copy = new LinkedHashMap<>(queryParams);
        copy.remove(TOKEN_PARAM);
        return copy;
    }
}
